# -*- coding: utf-8 -*-  
'''
训练gsdmm模型

Created on 2021年9月5日
@author: luoyi
'''
import sys
import os
#    Locate the project root (the nearest ancestor directory named exactly
#    'tbert') and append it to sys.path so the project-local packages imported
#    below (utils, data, models) can be resolved when this file runs as a script.
def find_project_root(start_dir, project_name='tbert'):
    '''Return the nearest ancestor of start_dir (inclusive) whose name is project_name.

    Matching whole path components avoids a substring-split pitfall: splitting
    the path on "tbert" cuts at the FIRST occurrence anywhere, so a parent such
    as /home/tbert_user/work/tbert would wrongly resolve to /home/tbert.

    If no path component equals project_name exactly, the legacy result
    start_dir.split(project_name)[0] + project_name is returned unchanged, so
    setups that relied on it keep the same value.

    @param start_dir: directory to start searching from (made absolute first)
    @param project_name: directory name that marks the project root
    @return: absolute path of the project root (legacy fallback when not found)
    '''
    start_dir = os.path.abspath(start_dir)
    current = start_dir
    while True:
        if os.path.basename(current) == project_name:
            return current
        parent = os.path.dirname(current)
        if parent == current:    # reached the filesystem root without a match
            break
        current = parent
    return start_dir.split(project_name)[0] + project_name


ROOT_PATH = find_project_root(os.path.dirname(__file__))
sys.path.append(ROOT_PATH)

import utils.conf as conf

from utils.iexicon import LiteWordsWarehouse
from data.sohu_thuc_news.lda_gsdmm_dataset import LdaGSDmmPreDataset
from models.gsdmm.nets_np import GSDMM


#    Training configuration for this run.
#    NOTE(review): everything below executes at import time; consider wrapping
#    it in an `if __name__ == '__main__':` guard if this module is ever imported.
SAVE_DOC_INTERVAL = 200     # forwarded to GSDMM(save_doc_inteval=...)
SAVE_EPOCH_INTERVAL = 1     # forwarded to GSDMM(save_epoch_inteval=...)
TRAIN_EPOCHS = 600          # forwarded to GSDMM.training(epochs=...)

#    Load the vocabulary pickles (word -> id map and word frequencies) of the
#    Sohu/THUCNews dataset into the LiteWordsWarehouse singleton.
print('初始化词库.')
words_warehouse = LiteWordsWarehouse.instance()
sohu_conf = conf.DATASET_SOHU_THUCNEWS
words_warehouse.load_pkl(
    word_id_path=sohu_conf.get_word_id_path(),
    word_frequency_path=sohu_conf.get_word_frequency_path(),
)

#    Build the pre-processed dataset consumed by the GSDMM trainer, then the
#    model itself, and run training over it.
gsdmm_ds = LdaGSDmmPreDataset()
gsdmm = GSDMM(save_doc_inteval=SAVE_DOC_INTERVAL,
              save_epoch_inteval=SAVE_EPOCH_INTERVAL)
gsdmm.training(gsdmm_ds, epochs=TRAIN_EPOCHS)

